from torch.utils.data import DataLoader

from src.utils import MetPolArgs
from src.dataset import MetPolDataset


def get_dataloader_and_statics(flag: str = 'train', args: "MetPolArgs | None" = None):
    """Build a DataLoader over ``MetPolDataset`` plus the dataset's static metadata.

    Args:
        flag: Dataset split selector, forwarded to ``MetPolDataset(flag=...)``.
            Shuffling is enabled only when ``flag == 'train'``.
        args: Configuration object. ``args.lead_time`` is forwarded to the
            dataset, and the attributes of ``args.data`` are passed as keyword
            arguments to ``torch.utils.data.DataLoader``. If omitted, a fresh
            ``MetPolArgs()`` is created for this call.

    Returns:
        A ``(dataloader, statics)`` tuple. ``statics`` is a dict with the keys
        ``"coord_info"``, ``"map_order"``, ``"target_mean"`` and
        ``"target_std"``, taken from the dataset's ``site_coords``,
        ``map_order``, ``target_mean`` and ``target_std`` attributes.
    """
    # Build the default per call: a `MetPolArgs()` default in the signature
    # would be created once at definition time and shared (and mutable)
    # across every call that omits `args`.
    if args is None:
        args = MetPolArgs()

    shuffle = flag == 'train'
    dataset = MetPolDataset(flag=flag, lead_time=args.lead_time)
    # NOTE(review): if `args.data` ever carries a `shuffle` attribute, this
    # call raises TypeError (duplicate keyword) — confirm the config excludes it.
    dataloader = DataLoader(dataset, shuffle=shuffle, **vars(args.data))

    # `dataset` is the same object as `dataloader.dataset`; read from it directly.
    statics = {
        "coord_info": dataset.site_coords,
        "map_order": dataset.map_order,
        "target_mean": dataset.target_mean,
        "target_std": dataset.target_std,
    }

    return dataloader, statics
